package lc;

import java.util.Arrays;

public class lc556 {
    public static void main(String[] args) {
        lc556 l = new lc556();
        l.nextGreaterElement(230241);
    }
    public int nextGreaterElement(int n) {
        //
        int[] arr = new int[32];
        int idx = 0;
        while (n > 0){
            arr[idx++] = n % 10;
            n /= 10;
        }
        int l = 1, r = 32, m = 0;
        while (l < r){
            if (arr[m] > arr[l++]){
                r = l;
                m++;
                l = m + 1;
            }
        }

        Arrays.sort(arr,0,l);
        int sum = 0;
        for(int j = idx - 1; j >= 0; j--){
            sum = sum * 10 + arr[j];
        }
        return sum;
    }

    public void swap(int[] arr, int i, int j){
        int tmp = arr[i];
        arr[i] = arr[j];
        arr[j] = tmp;
    }
}
